# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convert coco data to mindrecord format."""

import os

import numpy as np
from mindspore.mindrecord import FileWriter
from pycocotools.coco import COCO

from utils.class_factory import ClassFactory, ModuleType




def _has_only_empty_bbox(anno):
    return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)


def _count_visible_keypoints(anno):
    return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)


def has_valid_annotation(anno, min_keypoints_per_image=10):
    """Decide whether one image's annotation list is usable.

    Args:
        anno (list[dict]): COCO annotation dicts for a single image.
        min_keypoints_per_image (int): For keypoint annotations, the minimum
            number of visible keypoints (summed over all annotations of the
            image) required for the image to be valid. Default: 10.

    Returns:
        bool: False if ``anno`` is empty or every box is degenerate (some
        size entry <= 1). Otherwise True for box-only annotations, and for
        keypoint annotations whether at least ``min_keypoints_per_image``
        keypoints are visible.
    """
    # If it's empty, there is no annotation.
    if not anno:
        return False
    # If all boxes have close to zero area, there is no annotation.
    if _has_only_empty_bbox(anno):
        return False
    # Only the first annotation is inspected to decide whether this is a
    # keypoint task; box-only annotations are valid at this point.
    if "keypoints" not in anno[0]:
        return True
    # For keypoint detection, keep only images with enough visible keypoints.
    return _count_visible_keypoints(anno) >= min_keypoints_per_image


@ClassFactory.register(ModuleType.DATASET)
class Coco2MindRecord:
    """
    Convert a COCO dataset to a MindRecord file.

    Each written row holds the raw image file bytes (``"image"``) and an int32
    annotation array (``"annotation"``) whose rows are
    ``[x1, y1, x2, y2, label, iscrowd]``.

    Args:
        root (str): Directory holding the images; each image's ``file_name``
            from the annotation file is joined onto it.
        ann_file (str): Path to the COCO annotation file, as accepted by
            ``pycocotools.coco.COCO``.
        mindrecord_dir (str): Directory the MindRecord output is written to.
        remove_images_without_annos (bool): Skip images whose annotation list
            fails ``has_valid_annotation``. Default: True.
        filter_crowd_anno (bool): Drop annotations whose ``iscrowd`` flag is 1.
            Default: True.

    Examples:
        converter = Coco2MindRecord("coco/train2017",
                                    "coco/annotations/instances_train2017.json",
                                    "./mindrecord")
        converter()  # does nothing if the first output shard already exists
    """

    # Classes kept for training/testing. An annotation's label is the index of
    # its category name in this tuple; categories not listed here are dropped.
    COCO_CLASSES = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                    'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
                    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                    'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
                    'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                    'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                    'kite', 'baseball bat', 'baseball glove', 'skateboard',
                    'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
                    'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                    'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                    'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                    'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
                    'refrigerator', 'book', 'clock', 'vase', 'scissors',
                    'teddy bear', 'hair drier', 'toothbrush')

    def __init__(self, root, ann_file, mindrecord_dir,
                 remove_images_without_annos=True,
                 filter_crowd_anno=True):
        """Constructor for Coco2MindRecord"""
        self.data_root = root
        self.ann_file = ann_file
        self.mindrecord_dir = mindrecord_dir
        self.remove_images_without_annos = remove_images_without_annos
        self.filter_crowd_anno = filter_crowd_anno

    def create_coco_label(self):
        """Collect image paths and per-image box annotations from COCO.

        Returns:
            tuple: ``(image_files, image_anno_dict)``. ``image_files`` lists
            the image paths in COCO image-id order; ``image_anno_dict`` maps
            each path to a 2-D array with rows
            ``[x1, y1, x2, y2, label, iscrowd]``. An image left with no kept
            box gets the single placeholder row ``[0, 0, 0, 0, 0, 1]``.
        """
        coco = COCO(self.ann_file)

        # COCO category id -> category name.
        cat_ids = coco.loadCats(coco.getCatIds())
        classes_dict = {cat["id"]: cat["name"] for cat in cat_ids}

        # Category name -> training label. With an explicit class list the
        # label is the name's position in it; otherwise the COCO id is used.
        train_cls = self.COCO_CLASSES
        if train_cls is not None:
            train_cls_dict = {cls: i for i, cls in enumerate(train_cls)}
        else:
            train_cls_dict = {cat["name"]: cat["id"] for cat in cat_ids}

        image_files = []
        image_anno_dict = {}

        for img_id in coco.getImgIds():
            file_name = coco.loadImgs(img_id)[0]["file_name"]
            anno = coco.loadAnns(coco.getAnnIds(imgIds=img_id, iscrowd=None))

            # Filter images without usable annotations.
            if self.remove_images_without_annos and not has_valid_annotation(anno):
                continue

            image_path = os.path.join(self.data_root, file_name)
            annos = []
            for label in anno:
                bbox = label["bbox"]
                class_name = classes_dict[label["category_id"]]

                # Filter crowd annotations.
                if self.filter_crowd_anno and label["iscrowd"] == 1:
                    continue

                # Dict membership is equivalent to scanning train_cls (same
                # names) but avoids a linear search per annotation.
                if train_cls is not None and class_name not in train_cls_dict:
                    continue

                # bbox is [x, y, width, height]; convert to corner form.
                x1, x2 = bbox[0], bbox[0] + bbox[2]
                y1, y2 = bbox[1], bbox[1] + bbox[3]
                annos.append([x1, y1, x2, y2,
                              train_cls_dict[class_name],
                              int(label["iscrowd"])])

            image_files.append(image_path)
            if annos:
                image_anno_dict[image_path] = np.array(annos)
            else:
                # Keep the placeholder 2-D so it matches the [-1, 6] schema
                # used when writing; the row is flagged with iscrowd=1.
                image_anno_dict[image_path] = np.array([[0, 0, 0, 0, 0, 1]])

        return image_files, image_anno_dict

    def data_to_mindrecord_byte_image(self,
                                      prefix="coco.mindrecord",
                                      file_num=8):
        """Write all images and annotations to MindRecord.

        Args:
            prefix (str): Output file name inside ``mindrecord_dir``.
            file_num (int): Shard count passed to ``FileWriter``.
        """
        mindrecord_path = os.path.join(self.mindrecord_dir, prefix)
        writer = FileWriter(mindrecord_path, file_num)
        image_files, image_anno_dict = self.create_coco_label()

        fasterrcnn_json = {
            "image": {"type": "bytes"},
            "annotation": {"type": "int32", "shape": [-1, 6]},
        }
        writer.add_schema(fasterrcnn_json, "fasterrcnn_json")

        for image_name in image_files:
            with open(image_name, 'rb') as f:
                img = f.read()
            # Box coordinates are cast (truncated) to int32 to match the schema.
            annos = np.array(image_anno_dict[image_name], dtype=np.int32)
            row = {"image": img, "annotation": annos}
            writer.write_raw_data([row])
        writer.commit()

    def __call__(self, prefix='coco.mindrecord'):
        """Write the MindRecord file unless its first shard already exists.

        Args:
            prefix (str): Output file name inside ``mindrecord_dir``.
        """
        mindrecord_path = os.path.join(self.mindrecord_dir, prefix)
        # Checks for "<path>0"; assumes FileWriter suffixes shard files with
        # their index when file_num > 1 (the default is 8) — confirm.
        if os.path.exists(mindrecord_path + "0"):
            return
        # Forward the prefix so the file written is the one checked above.
        self.data_to_mindrecord_byte_image(prefix=prefix)
